/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the
 * "License"). Please refer to the License for details. You may not use this
 * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
 * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
 * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
 * for the full text of the License.
 */

#ifndef EXAMPLES_SORT_SORT_CUSTOM_H
#define EXAMPLES_SORT_SORT_CUSTOM_H
#include "kernel_operator.h"

namespace MyCustomKernel {
// Tiling parameters passed by value to KernelSort::Init.
struct VecTiling {
    // Element count; KernelSort::Init copies it into m_elementCount and derives
    // every repeat count and buffer size from it.
    uint32_t elementCount;
};

// Divisor for the Concat repeat count; also the Sort/Extract divisor when
// __CCE_AICORE__ <= 200 (see KernelSort::Init).
constexpr uint8_t ELEMENT_16{16};
// Divisor for the Sort/Extract repeat counts when __CCE_AICORE__ == 220.
constexpr uint8_t ELEMENT_32{32};
// Per-element multiplier for the calc/tmp buffer sizes; KernelSort::Init
// additionally scales the result by sizeof(T) when calling InitBuffer.
constexpr uint8_t BUFFER_SIZE{9};
// Depth of the input queue and the buffer count passed to its InitBuffer call.
constexpr uint8_t DOUBLE_SIZE{2};
// Per-element multipliers KernelSort::Compute passes to SetSize on the sorted
// and temporary tensors, selected by __CCE_AICORE__ and, on 220, by sizeof(T).
constexpr uint8_t LOCAL_SIZE_220_HALF{4};
constexpr uint8_t LOCAL_SIZE_220_FLOAT{2};
constexpr uint8_t LOCAL_SIZE_200{8};

// Kernel wrapper that copies a value tensor and a uint32_t index tensor from
// global memory, runs AscendC::Concat -> AscendC::Sort -> AscendC::Extract on
// them, and copies the extracted values and indices back to global memory.
//
// Template parameters:
//   T          - element type of the values (the __CCE_AICORE__ == 220 path
//                distinguishes sizeof(T) == sizeof(half) from everything else).
//   isFullSort - forwarded unchanged as the second template argument of
//                AscendC::Sort.
//
// Usage: call Init() once, then Process().
template <typename T, bool isFullSort>
class KernelSort {
public:
    __aicore__ inline KernelSort() {}

    // Binds the global-memory addresses and sizes every queue from the tiling.
    //
    //   dstValueGm / dstIndexGm - output addresses for values (T) and indices (uint32_t).
    //   srcValueGm / srcIndexGm - input addresses for values (T) and indices (uint32_t).
    //   tilingData              - supplies elementCount.
    //
    // NOTE(review): repeat counts use truncating integer division, so this
    // presumably expects elementCount to be a multiple of 16/32 — confirm
    // against the Concat/Sort/Extract API constraints.
    __aicore__ inline void Init(GM_ADDR dstValueGm, GM_ADDR dstIndexGm, GM_ADDR srcValueGm, GM_ADDR srcIndexGm,
                                VecTiling tilingData)
    {
        m_elementCount = tilingData.elementCount;

        // Repeat counts: Concat always divides by 16; Sort and Extract share
        // one divisor that depends on the target core version.
        m_concatRepeatTimes = m_elementCount / ELEMENT_16;
#if __CCE_AICORE__ == 220
        m_sortRepeatTimes = m_elementCount / ELEMENT_32;
        m_extractRepeatTimes = m_sortRepeatTimes;
#elif __CCE_AICORE__ <= 200
        m_sortRepeatTimes = m_elementCount / ELEMENT_16;
        m_extractRepeatTimes = m_sortRepeatTimes;
#endif

        // Input and output buffers are sized for uint32_t elements regardless
        // of T, because the same queues carry the index tensors.
        m_inBufferSize = m_elementCount * sizeof(uint32_t);
        m_outBufferSize = m_inBufferSize;
        // Scratch sizes are element counts here; they are scaled by sizeof(T)
        // at the InitBuffer calls below.
        m_calcBufferSize = m_elementCount * BUFFER_SIZE;
        m_tmpBufferSize = m_calcBufferSize;

        m_valueGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(srcValueGm));
        m_indexGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ uint32_t*>(srcIndexGm));
        m_dstValueGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(dstValueGm));
        m_dstIndexGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ uint32_t*>(dstIndexGm));

        // The input queue gets two buffers: one for values, one for indices.
        m_pipe.InitBuffer(m_queIn, DOUBLE_SIZE, m_inBufferSize);
        m_pipe.InitBuffer(m_queOut, 1, m_outBufferSize);
        m_pipe.InitBuffer(m_queOutIdx, 1, m_outBufferSize);
        m_pipe.InitBuffer(m_queCalc, 1, m_calcBufferSize * sizeof(T));
        m_pipe.InitBuffer(m_queTmp, 1, m_tmpBufferSize * sizeof(T));
        m_pipe.InitBuffer(m_queTmpConcat, 1, m_tmpBufferSize * sizeof(T));
    }

    // Runs the three stages once, in order: load, sort, store.
    __aicore__ inline void Process()
    {
        CopyIn();
        Compute();
        CopyOut();
    }

private:
    // Loads m_elementCount values and then m_elementCount indices from global
    // memory, enqueueing both on m_queIn (values first).
    __aicore__ inline void CopyIn()
    {
        AscendC::LocalTensor<T> srcValues = m_queIn.AllocTensor<T>();
        AscendC::DataCopy(srcValues, m_valueGlobal, m_elementCount);
        m_queIn.EnQue(srcValues);

        AscendC::LocalTensor<uint32_t> srcIndices = m_queIn.AllocTensor<uint32_t>();
        AscendC::DataCopy(srcIndices, m_indexGlobal, m_elementCount);
        m_queIn.EnQue(srcIndices);
    }

    // Dequeues the inputs (values first, matching CopyIn), runs
    // Concat -> Sort -> Extract, releases the scratch and input tensors, and
    // enqueues the extracted values and indices for CopyOut.
    __aicore__ inline void Compute()
    {
        AscendC::LocalTensor<T> srcValues = m_queIn.DeQue<T>();
        AscendC::LocalTensor<uint32_t> srcIndices = m_queIn.DeQue<uint32_t>();
        AscendC::LocalTensor<T> sortedBuf = m_queCalc.AllocTensor<T>();
        AscendC::LocalTensor<T> concatScratch = m_queTmpConcat.AllocTensor<T>();
        AscendC::LocalTensor<T> sortScratch = m_queTmp.AllocTensor<T>();
        AscendC::LocalTensor<T> outValues = m_queOut.AllocTensor<T>();
        AscendC::LocalTensor<uint32_t> outIndices = m_queOutIdx.AllocTensor<uint32_t>();

        // concatResult is not allocated from any queue; it is passed to Concat
        // as its first argument and then fed to Sort.
        // NOTE(review): presumably Concat binds it to existing storage —
        // confirm against the Concat API reference.
        AscendC::LocalTensor<T> concatResult;
        AscendC::Concat(concatResult, srcValues, concatScratch, m_concatRepeatTimes);
        srcValues.SetSize(m_elementCount);

        // Declare how many elements of the sort output and sort scratch
        // tensors are in use; the per-element multiplier depends on the core
        // version and, on 220, on the width of T.
#if __CCE_AICORE__ == 220
        const uint32_t workspaceElems =
            m_elementCount * ((sizeof(T) == sizeof(half)) ? LOCAL_SIZE_220_HALF : LOCAL_SIZE_220_FLOAT);
        sortedBuf.SetSize(workspaceElems);
        sortScratch.SetSize(workspaceElems);
#elif __CCE_AICORE__ <= 200
        const uint32_t workspaceElems = m_elementCount * LOCAL_SIZE_200;
        sortedBuf.SetSize(workspaceElems);
        sortScratch.SetSize(workspaceElems);
#endif

        AscendC::Sort<T, isFullSort>(sortedBuf, concatResult, srcIndices, sortScratch, m_sortRepeatTimes);
        AscendC::Extract(outValues, outIndices, sortedBuf, m_extractRepeatTimes);

        // Release scratch and inputs, then publish the results.
        m_queTmp.FreeTensor(sortScratch);
        m_queTmpConcat.FreeTensor(concatScratch);
        m_queIn.FreeTensor(srcValues);
        m_queIn.FreeTensor(srcIndices);
        m_queCalc.FreeTensor(sortedBuf);
        m_queOut.EnQue(outValues);
        m_queOutIdx.EnQue(outIndices);
    }

    // Writes m_elementCount sorted values and indices back to global memory
    // and frees the output tensors.
    __aicore__ inline void CopyOut()
    {
        AscendC::LocalTensor<T> outValues = m_queOut.DeQue<T>();
        AscendC::LocalTensor<uint32_t> outIndices = m_queOutIdx.DeQue<uint32_t>();
        AscendC::DataCopy(m_dstValueGlobal, outValues, m_elementCount);
        AscendC::DataCopy(m_dstIndexGlobal, outIndices, m_elementCount);
        m_queOut.FreeTensor(outValues);
        m_queOutIdx.FreeTensor(outIndices);
    }

private:
    AscendC::TPipe m_pipe;
    // Input queue shared by the value and index tensors.
    AscendC::TQue<AscendC::TPosition::VECIN, DOUBLE_SIZE> m_queIn;
    // Output queues for the extracted values and indices.
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> m_queOut;
    AscendC::TQue<AscendC::TPosition::VECOUT, 1> m_queOutIdx;
    // Scratch queues: Sort temporary, Concat temporary, and Sort destination.
    AscendC::TQue<AscendC::TPosition::VECIN, 1> m_queTmp;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> m_queTmpConcat;
    AscendC::TQue<AscendC::TPosition::VECIN, 1> m_queCalc;
    // Global-memory views bound in Init.
    AscendC::GlobalTensor<T> m_valueGlobal;
    AscendC::GlobalTensor<uint32_t> m_indexGlobal;
    AscendC::GlobalTensor<T> m_dstValueGlobal;
    AscendC::GlobalTensor<uint32_t> m_dstIndexGlobal;
    // Overwritten by Init with tilingData.elementCount.
    uint32_t m_elementCount = 64;
    // Repeat counts passed to Concat, Sort and Extract respectively.
    uint32_t m_concatRepeatTimes;
    uint32_t m_sortRepeatTimes;
    uint32_t m_extractRepeatTimes;
    // Input/output buffer sizes as passed to InitBuffer.
    uint32_t m_inBufferSize;
    uint32_t m_outBufferSize;
    // Scratch sizes in elements; multiplied by sizeof(T) when passed to InitBuffer.
    uint32_t m_calcBufferSize;
    uint32_t m_tmpBufferSize;
};

} // namespace MyCustomKernel

#endif // EXAMPLES_SORT_SORT_CUSTOM_H
